{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import sys\n",
    "import re\n",
    "from subprocess import call\n",
    "import numpy as np\n",
    "from nltk import TweetTokenizer\n",
    "from nltk.tokenize.stanford import StanfordTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Downloading the models\n",
    "\n",
    "As mentioned in the README, the following pretrained models are available for download:\n",
    "\n",
    "- [sent2vec_wiki_unigrams](https://drive.google.com/open?id=0B6VhzidiLvjSa19uYWlLUEkzX3c) 5GB (600dim, trained on English Wikipedia)\n",
    "- [sent2vec_wiki_bigrams](https://drive.google.com/open?id=0B6VhzidiLvjSaER5YkJUdWdPWU0) 16GB (700dim, trained on English Wikipedia)\n",
    "- [sent2vec_twitter_unigrams](https://drive.google.com/open?id=0B6VhzidiLvjSaVFLM0xJNk9DTzg) 13GB (700dim, trained on English tweets)\n",
    "- [sent2vec_twitter_bigrams](https://drive.google.com/open?id=0B6VhzidiLvjSeHI4cmdQdXpTRHc) 23GB (700dim, trained on English tweets)\n",
    "- [sent2vec_toronto books_unigrams](https://drive.google.com/open?id=0B6VhzidiLvjSOWdGM0tOX1lUNEk) 2GB (700dim, trained on the [BookCorpus dataset](http://yknzhu.wixsite.com/mbweb))\n",
    "- [sent2vec_toronto books_bigrams](https://drive.google.com/open?id=0B6VhzidiLvjSdENLSEhrdWprQ0k) 7GB (700dim, trained on the [BookCorpus dataset](http://yknzhu.wixsite.com/mbweb))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "From here, one simple way to get sentence embeddings is to use the `print-sentence-vectors` command as shown in the README. For the models to work as intended, the input should go through the same preprocessing that was used during training. The code below wraps the `print-sentence-vectors` command and handles tokenization so that the input matches what each model expects."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linking things together\n",
    "\n",
    "To use the Stanford NLP tokenizer with NLTK, you need `stanford-postagger.jar`, which is available in the [CoreNLP library package](http://stanfordnlp.github.io/CoreNLP/).\n",
    "\n",
    "Then point the paths in the following cell to the fastText executable, the jar, and the downloaded models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "FASTTEXT_EXEC_PATH = os.path.abspath(\"./fasttext\")\n",
    "\n",
    "BASE_SNLP_PATH = \"/home/path/to/stanford_NLP/stanford-postagger-2016-10-31/\"\n",
    "SNLP_TAGGER_JAR = os.path.join(BASE_SNLP_PATH, \"stanford-postagger.jar\")\n",
    "\n",
    "MODEL_WIKI_UNIGRAMS = os.path.abspath(\"./wiki_unigrams.bin\")\n",
    "MODEL_WIKI_BIGRAMS = os.path.abspath(\"./wiki_bigrams.bin\")\n",
    "MODEL_TWITTER_UNIGRAMS = os.path.abspath('./twitter_unigrams.bin')\n",
    "MODEL_TWITTER_BIGRAMS = os.path.abspath('./twitter_bigrams.bin')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating sentence embeddings\n",
    "\n",
    "Now you can just run the following cells:\n",
    "\n",
    "## Utils for tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def tokenize(tknzr, sentence, to_lower=True):\n",
    "    \"\"\"Arguments:\n",
    "        - tknzr: a tokenizer implementing the NLTK tokenizer interface\n",
    "        - sentence: a string to be tokenized\n",
    "        - to_lower: lowercasing or not\n",
    "    \"\"\"\n",
    "    sentence = sentence.strip()\n",
    "    sentence = ' '.join([format_token(x) for x in tknzr.tokenize(sentence)])\n",
    "    if to_lower:\n",
    "        sentence = sentence.lower()\n",
    "    sentence = re.sub(r'(www\\.[^\\s]+)|(https?://[^\\s]+)', '<url>', sentence)  # replace URLs with <url>\n",
    "    sentence = re.sub(r'@[^\\s]+', '<user>', sentence)  # replace @user268 with <user>\n",
    "    return sentence\n",
    "\n",
    "def format_token(token):\n",
    "    \"\"\"Map Penn Treebank bracket escapes (e.g. -LRB-) back to the literal bracket characters.\"\"\"\n",
    "    if token == '-LRB-':\n",
    "        token = '('\n",
    "    elif token == '-RRB-':\n",
    "        token = ')'\n",
    "    elif token == '-RSB-':\n",
    "        token = ']'\n",
    "    elif token == '-LSB-':\n",
    "        token = '['\n",
    "    elif token == '-LCB-':\n",
    "        token = '{'\n",
    "    elif token == '-RCB-':\n",
    "        token = '}'\n",
    "    return token\n",
    "\n",
    "def tokenize_sentences(tknzr, sentences, to_lower=True):\n",
    "    \"\"\"Arguments:\n",
    "        - tknzr: a tokenizer implementing the NLTK tokenizer interface\n",
    "        - sentences: a list of sentences\n",
    "        - to_lower: lowercasing or not\n",
    "    \"\"\"\n",
    "    return [tokenize(tknzr, s, to_lower) for s in sentences]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utils for inferring embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_embeddings_for_preprocessed_sentences(sentences, model_path, fasttext_exec_path):\n",
    "    \"\"\"Arguments:\n",
    "        - sentences: a list of preprocessed sentences\n",
    "        - model_path: a path to the sent2vec .bin model\n",
    "        - fasttext_exec_path: a path to the fasttext executable\n",
    "    \"\"\"\n",
    "    timestamp = str(time.time())\n",
    "    test_path = os.path.abspath('./'+timestamp+'_fasttext.test.txt')\n",
    "    embeddings_path = os.path.abspath('./'+timestamp+'_fasttext.embeddings.txt')\n",
    "    dump_text_to_disk(test_path, sentences)\n",
    "    # feed the sentences on stdin and capture the vectors from stdout, without going through a shell\n",
    "    with open(test_path, 'r') as in_stream, open(embeddings_path, 'w') as out_stream:\n",
    "        call([fasttext_exec_path, 'print-sentence-vectors', model_path],\n",
    "             stdin=in_stream, stdout=out_stream)\n",
    "    embeddings = read_embeddings(embeddings_path)\n",
    "    os.remove(test_path)\n",
    "    os.remove(embeddings_path)\n",
    "    assert len(sentences) == len(embeddings)\n",
    "    return np.array(embeddings)\n",
    "\n",
    "def read_embeddings(embeddings_path):\n",
    "    \"\"\"Arguments:\n",
    "        - embeddings_path: path to the embeddings\n",
    "    \"\"\"\n",
    "    embeddings = []\n",
    "    with open(embeddings_path, 'r') as in_stream:\n",
    "        for line in in_stream:\n",
    "            # each line holds one vector as space-separated floats\n",
    "            embeddings.append([float(value) for value in line.split()])\n",
    "    return embeddings\n",
    "\n",
    "def dump_text_to_disk(file_path, X, Y=None):\n",
    "    \"\"\"Arguments:\n",
    "        - file_path: where to dump the data\n",
    "        - X: list of sentences to dump\n",
    "        - Y: labels, if any\n",
    "    \"\"\"\n",
    "    with open(file_path, 'w') as out_stream:\n",
    "        if Y is not None:\n",
    "            for x, y in zip(X, Y):\n",
    "                out_stream.write('__label__'+str(y)+' '+x+' \\n')\n",
    "        else:\n",
    "            for x in X:\n",
    "                out_stream.write(x+' \\n')\n",
    "\n",
    "def get_sentence_embeddings(sentences, ngram='bigrams', model='concat_wiki_twitter'):\n",
    "    \"\"\" Returns a numpy matrix of embeddings for one of the published models. It\n",
    "    handles tokenization and can be given raw sentences.\n",
    "    Arguments:\n",
    "        - sentences: a list of raw sentences ['Once upon a time', 'This is another sentence.', ...]\n",
    "        - ngram: 'unigrams' or 'bigrams'\n",
    "        - model: 'wiki', 'twitter', or 'concat_wiki_twitter'\n",
    "    \"\"\"\n",
    "    if model not in ('wiki', 'twitter', 'concat_wiki_twitter'):\n",
    "        raise ValueError('Unknown model: ' + repr(model))\n",
    "    wiki_embeddings = None\n",
    "    twitter_embeddings = None\n",
    "    if model in ('wiki', 'concat_wiki_twitter'):\n",
    "        tknzr = StanfordTokenizer(SNLP_TAGGER_JAR, encoding='utf-8')\n",
    "        # join all sentences so the Stanford tokenizer is invoked only once\n",
    "        s = ' <delimiter> '.join(sentences)\n",
    "        tokenized_sentences_SNLP = tokenize_sentences(tknzr, [s])\n",
    "        tokenized_sentences_SNLP = tokenized_sentences_SNLP[0].split(' <delimiter> ')\n",
    "        assert len(tokenized_sentences_SNLP) == len(sentences)\n",
    "        wiki_model = MODEL_WIKI_UNIGRAMS if ngram == 'unigrams' else MODEL_WIKI_BIGRAMS\n",
    "        wiki_embeddings = get_embeddings_for_preprocessed_sentences(\n",
    "            tokenized_sentences_SNLP, wiki_model, FASTTEXT_EXEC_PATH)\n",
    "    if model in ('twitter', 'concat_wiki_twitter'):\n",
    "        tknzr = TweetTokenizer()\n",
    "        tokenized_sentences_NLTK_tweets = tokenize_sentences(tknzr, sentences)\n",
    "        twitter_model = MODEL_TWITTER_UNIGRAMS if ngram == 'unigrams' else MODEL_TWITTER_BIGRAMS\n",
    "        twitter_embeddings = get_embeddings_for_preprocessed_sentences(\n",
    "            tokenized_sentences_NLTK_tweets, twitter_model, FASTTEXT_EXEC_PATH)\n",
    "    if model == 'twitter':\n",
    "        return twitter_embeddings\n",
    "    if model == 'wiki':\n",
    "        return wiki_embeddings\n",
    "    return np.concatenate((wiki_embeddings, twitter_embeddings), axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use case\n",
    "\n",
    "To get embeddings, call the `get_sentence_embeddings` function. Its parameters are:\n",
    "- sentences: a list of unprocessed sentences\n",
    "- ngram: either `bigrams` or `unigrams`\n",
    "- model: `wiki`, `twitter` or `concat_wiki_twitter`\n",
    "\n",
    "Loading the models can take some time, but once they are loaded, inference is fast.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = ['Once upon a time.', 'And now for something completely different.']\n",
    "\n",
    "my_embeddings = get_sentence_embeddings(sentences, ngram='unigrams', model='twitter')\n",
    "print(my_embeddings.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "et voila :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
